Open In Colab

Exercise 1: Physically-Constrained Post-Processing¶

⛈ Welcome to the first exercise of the "Physics-Guided Machine Learning" e-learning module of ECMWF's MOOC on Machine Learning in Weather and Climate 🌤

By the end of this first exercise, you will:

  1. Understand how constraining a machine learning model's output using physical knowledge can make it more consistent and trustworthy,
  2. Know how to use custom layers to enforce general nonlinear constraints within a neural network, and
  3. Practice what you learned about post-processing in Tier 1 on a real-world application case.
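To give you a taste of point 2 before we dive in, here is a minimal, purely illustrative sketch (not the exact layer used later in this notebook) of a custom PyTorch layer that enforces a thermodynamic constraint by construction: the predicted dew point temperature can never exceed the predicted air temperature, because the dew point depression is passed through a softplus and is therefore nonnegative.

```python
import torch
import torch.nn as nn

class DewPointConstraint(nn.Module):
    """Toy constraint layer (illustrative): maps an unconstrained
    network output [T, raw] to [T, Td] with Td = T - softplus(raw),
    so the constraint Td <= T holds for any value of `raw`."""
    def forward(self, out):
        t_air, raw = out[..., 0], out[..., 1]
        td = t_air - nn.functional.softplus(raw)  # softplus >= 0, so td <= t_air
        return torch.stack([t_air, td], dim=-1)

layer = DewPointConstraint()
pred = layer(torch.randn(4, 2))  # constraint holds for random inputs
```

The key design choice is that the constraint is *architectural*: it is satisfied exactly for every input, rather than merely encouraged through a penalty term in the loss.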

While this notebook's completion time will vary widely depending on your programming experience, we estimate it will take this MOOC's students about 30 minutes on average. This notebook provides a minimal reproducible example of the work described in the following preprint:

Zanetta, Francesco, Daniele Nerini, Tom Beucler, and Mark A. Liniger. "Physics-constrained deep learning postprocessing of temperature and humidity.",

and relies on the more general physically-constrained framework described in the following preprint:

Beucler, T., M. Pritchard, S. Rasp, P. Gentine, J. Ott, and P. Baldi. "Enforcing analytic constraints in neural-networks emulating physical systems."

We provide an anonymized sample of our data. This pedagogical exercise would not have been possible without the source code and contribution of Francesco Zanetta (MeteoSwiss, ETH).

We will be relying on PyTorch, whose documentation you can find here, and the notebooks assume that you will run them on Google Colab (Google Colab tutorial at this link).

While everything can be run locally and there are only a handful of lines that use Google-specific libraries, we encourage beginners to use Google Colab to avoid running into Python virtual-environment issues.

Before we get started, if you are struggling with some of the exercises, do not hesitate to:

  • Use a direct Internet search or Stack Overflow
  • Debug your program, e.g. by following this tutorial
  • Use assertions, e.g. by following this tutorial
  • Ask for help on the MOOC's Moodle Forum


🌧 A storm is quickly approaching the MeteoSwiss agency in Locarno! Will you be able to post-process the weather it brings without violating the laws of thermodynamics? 🌡

Source: Photo by Marcell Faber licensed under the Adobe Stock standard license

Part I. Configuration and requirements¶

In [1]:
#@title  Run this cell for preliminary requirements. Double click for the source code
# Install Python 3.10 and register it as the default python3
# (the fallback alternative points at Colab's system Python 3.9)
!sudo apt-get update -y
!sudo apt-get install python3.10
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 2
!curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
!python3 get-pip.py --force-reinstall
!python3 -m pip install ipython ipython_genutils ipykernel jupyter_console prompt_toolkit httplib2 astor
!ln -s /usr/local/lib/python3.9/dist-packages/google \
       /usr/local/lib/python3.10/dist-packages/google

# pooch is needed below to download the data
!pip install numpy torch xarray==2022.10.0 netcdf4 pooch
[apt-get and pip output condensed for readability]
Setting up python3.10 (3.10.10-1+focal1) ...
Successfully installed pip-23.0.1
Successfully installed astor-0.8.1 jedi-0.18.2
Successfully installed cftime-1.6.2 netcdf4-1.6.3 xarray-2022.10.0
In [3]:
#@title  Run this cell for Python library imports. Double click for the source code
from itertools import chain

import os
import pooch
import numpy as np 
import xarray as xr
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch import optim

torch.manual_seed(1)
np.random.seed(1)
In [4]:
#@title Run this cell to automatically save figures in the right place. Double click for the source code
# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=16)
mpl.rc('xtick', labelsize=14)
mpl.rc('ytick', labelsize=14)

# Where to save the figures
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "postprocessing"
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
os.makedirs(IMAGES_PATH, exist_ok=True)

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    path = os.path.join(IMAGES_PATH, fig_id + "." + fig_extension)
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format=fig_extension, dpi=resolution)
In [5]:
#@title  Run this cell to load the data using the pooch library. Double click for the source code

path_data = 'https://unils-my.sharepoint.com/:u:/g/personal/tom_beucler_unil_ch/'

# load training data 
x_path = path_data + 'EdAG3RBBgk5Kmvo54RPgT2kBp-NJqqGF6Il-gTmh9DbdeA?download=1'
y_path = path_data + 'EdVQCVKqnb9Bh495opeuRCEBBZFPDdG0g3xSpIFgNGJeJA?download=1'
x_open = pooch.retrieve(x_path,known_hash='c6acaf62051b81dfd3d5a4aa516d545615fd2597c8c38f4db4e571a621201878')
y_open = pooch.retrieve(y_path,known_hash='6265a5f0272e5427c823b95725b8aabbc48a9a97d7554fd5732e6c4b480f3ab3')

# load saved model's weights and biases to optionally accelerate
# the notebook's completion
u_path = path_data + 'ERb1IhuBFfZAjgcAF8N6pikBWet4WBZtheh9zvOyH7QNUg?download=1'
u_open = pooch.retrieve(u_path, known_hash='9dfb659526fa686062057de77bcb2d14ca46fc11212e799e9a5ec7175679b756')
a_path = path_data + 'EQuCPW4q2ilFpFQSZz_UDXYBjR2ZRRcJCDFW-EbI_xWKGg?download=1'
a_open = pooch.retrieve(a_path, known_hash='c6930d908c21785217b20895e19dbf6ddf6cd312b95b331ba590e2267d03e6f4')
l_path = path_data + 'EVb6Gw0ri6JJsazJ6mxUMXUBm4ndfmMjoj0M9gdIPDkd3A?download=1'
l_open = pooch.retrieve(l_path, known_hash = 'ab7e9013631d29806a8f73fbf7275cb37ca3561cb0d0c532107ea00ae496b187')

# build pointers to the data in the notebook using Xarray
x = (
    xr.open_dataset(x_open)
    .set_coords(["forecast_reference_time","t", "station_id"])
    .set_xindex("forecast_reference_time")
    .to_array("var")
    .transpose("s","var")
)
y = (
    xr.open_dataset(y_open)
    .set_coords(["forecast_reference_time","t"])
    .set_xindex("forecast_reference_time")
    .to_array("var")
    .transpose("s","var")
)
Downloading data from 'https://unils-my.sharepoint.com/:u:/g/personal/tom_beucler_unil_ch/EdAG3RBBgk5Kmvo54RPgT2kBp-NJqqGF6Il-gTmh9DbdeA?download=1' to file '/root/.cache/pooch/9cf67a523eb07cc3cce65fc8ca1e7c3a-EdAG3RBBgk5Kmvo54RPgT2kBp-NJqqGF6Il-gTmh9DbdeA'.
Downloading data from 'https://unils-my.sharepoint.com/:u:/g/personal/tom_beucler_unil_ch/EdVQCVKqnb9Bh495opeuRCEBBZFPDdG0g3xSpIFgNGJeJA?download=1' to file '/root/.cache/pooch/284b1ede27900e4591c2a3de05037b03-EdVQCVKqnb9Bh495opeuRCEBBZFPDdG0g3xSpIFgNGJeJA'.
Downloading data from 'https://unils-my.sharepoint.com/:u:/g/personal/tom_beucler_unil_ch/ERb1IhuBFfZAjgcAF8N6pikBWet4WBZtheh9zvOyH7QNUg?download=1' to file '/root/.cache/pooch/fb563202f2c0995fb7fe9938ffdbcb6e-ERb1IhuBFfZAjgcAF8N6pikBWet4WBZtheh9zvOyH7QNUg'.
Downloading data from 'https://unils-my.sharepoint.com/:u:/g/personal/tom_beucler_unil_ch/EQuCPW4q2ilFpFQSZz_UDXYBjR2ZRRcJCDFW-EbI_xWKGg?download=1' to file '/root/.cache/pooch/5a7910e0daed054648ef2dd170a95721-EQuCPW4q2ilFpFQSZz_UDXYBjR2ZRRcJCDFW-EbI_xWKGg'.
Downloading data from 'https://unils-my.sharepoint.com/:u:/g/personal/tom_beucler_unil_ch/EVb6Gw0ri6JJsazJ6mxUMXUBm4ndfmMjoj0M9gdIPDkd3A?download=1' to file '/root/.cache/pooch/b410f061e36b81b53fabc4a1cdccbf61-EVb6Gw0ri6JJsazJ6mxUMXUBm4ndfmMjoj0M9gdIPDkd3A'.
In [8]:
x
Out[8]:
<xarray.DataArray (s: 23787181, var: 11)>
array([[ 5.384186  ,  3.8263855 ,  1.5578003 , ...,  0.70710677,
        -0.8713402 ,  0.49067938],
       [11.651154  ,  9.337921  ,  2.3132324 , ...,  0.70710677,
        -0.8713402 ,  0.49067938],
       [11.587677  ,  9.549835  ,  2.0378418 , ...,  0.70710677,
        -0.8713402 ,  0.49067938],
       ...,
       [13.147949  ,  4.604645  ,  8.543304  , ...,  0.        ,
         0.07717546, -0.9970175 ],
       [ 8.020508  ,  0.8123169 ,  7.208191  , ...,  0.        ,
         0.07717546, -0.9970175 ],
       [12.391449  ,  7.90329   ,  4.488159  , ...,  0.        ,
         0.07717546, -0.9970175 ]], dtype=float32)
Coordinates:
  * forecast_reference_time  (s) datetime64[ns] 2016-06-01 ... 2022-10-01
    t                        (s) int32 ...
    station_id               (s) int32 ...
  * var                      (var) object 'coe:air_temperature_ensavg' ... 't...
Dimensions without coordinates: s
Full listing of the 'var' coordinate (the 11 features):
array(['coe:air_temperature_ensavg', 'coe:dew_point_temperature_ensavg',
       'coe:dew_point_depression_ensavg', 'coe:surface_air_pressure_ensavg',
       'coe:relative_humidity_ensavg', 'coe:water_vapor_mixing_ratio_ensavg',
       'coe:leadtime', 'time:cos_hourofday', 'time:sin_hourofday',
       'time:cos_dayofyear', 'time:sin_dayofyear'], dtype=object)
In [10]:
x.shape, y.shape
Out[10]:
((23787181, 11), (23787181, 5))

Part II. Pre-processing your data for deep learning¶

Following best machine learning practices, we recommend splitting your data into three folds:

  1. A training set, which is used to optimize the trainable parameters (weights and biases) of your machine learning model,
  2. A validation set, which is used to optimize the hyperparameters of your machine learning model, and check for underfitting/overfitting,
  3. A test set, which is used for independent testing after the trainable parameters and hyperparameters have been optimized.

To avoid correlations between folds, we will split the data using non-overlapping time periods. To be able to compare your results with ours, we encourage you to use the following dates for the split (note that xarray's label-based date slices include both endpoints):

  1. Training: Jan 1, 2017 to Dec 25, 2019
  2. Validation: Jan 1, 2020 to Dec 25, 2020
  3. Test: Jan 1, 2021 to Jan 1, 2022

💡 For all questions, you can write your own code or complete the proposed code by filling in the underscores appropriately

Q1) Split your data into training, validation, and test sets¶

The data was loaded as an Xarray DataArray. Assuming you executed all of Part I's code cells, the inputs were loaded in x and the outputs were loaded in y. We recommend:

  1. Assigning the training inputs to train_x and the training outputs to train_y
  2. Assigning the validation inputs to val_x and the validation outputs to val_y
  3. Assigning the test inputs to test_x and the test outputs to test_y

Hints:

  1. You may want to check out the coordinates of x and y by simply typing x or y in an empty code cell and executing it.
  2. If you're unsure how to select an exact time period, you can read up on Xarray's time-series data
  3. Xarray's sel method allows you to select data based on labels instead of integers. If you prefer to select data based on integers, you can use the isel method.
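To illustrate hint 3 on a toy DataArray (the coordinate name `time` below is illustrative; the real data uses `forecast_reference_time`), note that label-based date slices include both endpoints, while integer slices are end-exclusive as usual in Python:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy DataArray with a datetime dimension coordinate
da = xr.DataArray(np.arange(10.), dims="time",
                  coords={"time": pd.date_range("2020-01-01", periods=10)})

# Label-based selection with .sel: date slices include both endpoints
by_label = da.sel(time=slice("2020-01-03", "2020-01-05"))

# Integer-based selection with .isel: Python slices are end-exclusive
by_index = da.isel(time=slice(2, 5))

# Both select the same three samples (Jan 3, 4, and 5)
assert (by_label.values == by_index.values).all()
```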
In [6]:
# Here's an empty code cell to look at the data, etc.
# You can add or remove code and text cells via the "Insert" menu
In [14]:
train_sel = dict(forecast_reference_time=slice('2017-01-01','2019-12-25'))
val_sel = dict(forecast_reference_time=slice('2020-01-01','2020-12-25'))
test_sel = dict(forecast_reference_time=slice('2021-01-01','2022-01-01'))

train_x, train_y = x.sel(train_sel), y.sel(train_sel)
val_x, val_y = x.sel(val_sel), y.sel(val_sel)
test_x, test_y = x.sel(test_sel), y.sel(test_sel)
In [18]:
train_x
Out[18]:
<xarray.DataArray (s: 11142117, var: 11)>
array([[-2.2351379e+00, -1.0240204e+01,  8.0050659e+00, ...,
         7.0710677e-01,  9.9981350e-01,  1.9311870e-02],
       [ 1.4291077e+00, -6.3759460e+00,  7.8050537e+00, ...,
         7.0710677e-01,  9.9981350e-01,  1.9311870e-02],
       [-3.4679260e+00, -6.4877930e+00,  3.0198669e+00, ...,
         7.0710677e-01,  9.9981350e-01,  1.9311870e-02],
       ...,
       [ 6.6216431e+00, -1.7839050e+00,  8.4055481e+00, ...,
         1.2246469e-16,  9.9966848e-01, -2.5747914e-02],
       [ 1.8598328e+00, -2.5563354e+00,  4.4161682e+00, ...,
         1.2246469e-16,  9.9966848e-01, -2.5747914e-02],
       [ 4.0552368e+00, -4.5181274e-01,  4.5070496e+00, ...,
         1.2246469e-16,  9.9966848e-01, -2.5747914e-02]], dtype=float32)
Coordinates:
  * forecast_reference_time  (s) datetime64[ns] 2017-01-01 ... 2019-12-25T12:...
    t                        (s) int32 ...
    station_id               (s) int32 ...
  * var                      (var) object 'coe:air_temperature_ensavg' ... 't...
Dimensions without coordinates: s
In [ ]:
#@title A possible solution for Q1
train_sel = dict(forecast_reference_time=slice("2017-01-01","2019-12-25"))
val_sel = dict(forecast_reference_time=slice("2020-01-01","2020-12-25"))
test_sel = dict(forecast_reference_time=slice("2021-01-01","2022-01-01"))

train_x, train_y = x.sel(train_sel), y.sel(train_sel)
val_x, val_y = x.sel(val_sel), y.sel(val_sel)
test_x, test_y = x.sel(test_sel), y.sel(test_sel)
In [22]:
train_x.plot.hist()
Out[22]:
(array([95947989., 15473181.,        0.,        0.,        0.,        0.,
          102225.,  1682193.,  3284147.,  6073552.]),
 array([ -49.68917847,   55.71460724,  161.11839294,  266.52218628,
         371.92596436,  477.32977295,  582.73352051,  688.1373291 ,
         793.5411377 ,  898.94488525, 1004.34869385]),
 <BarContainer object of 10 artists>)

Q2) Explore your training data to gain intuition on the meteorological variables to post-process and their distributions¶

You may use the built-in Xarray histogram plotting method to quickly visualize your train_x and train_y DataArrays. The code to complete gives you an example of how to visualize one variable, while the solution visualizes all features.

  • The bins argument sets the number of bins
  • The alpha argument sets the transparency (between 0 and 1)
  • The color argument sets the histogram's color (e.g., k for black and b for blue)


List of predictors (or features or inputs) and predictands (or targets or outputs) used in this post-processing example

Table 1 of Zanetta et al. (2022)

In [ ]:
# Feel free to explore the data using this cell
In [24]:
Nbin = 50 # Number of bins
fz = 12 # Fontsize

train_x[:,0].plot.hist(bins=Nbin, alpha=0.5, color='r',
                       label='Forecast to Post-process (Feature)',
                       figsize=(15,5))
train_y[:,0].plot.hist(bins=Nbin, alpha=0.5, color='b',
                        label='Observations (Target)')

plt.title('Air temperature (C)', fontsize=fz); 
plt.legend(fontsize=fz)
Out[24]:
<matplotlib.legend.Legend at 0x7fe5841c9580>
In [26]:
train_y
Out[26]:
<xarray.DataArray (s: 11142117, var: 5)>
array([[ 8.0000001e-01, -1.9355097e+01,  8.1570001e+02,  2.0400000e+01,
         1.0083307e+00],
       [-5.9000001e+00, -7.2432714e+00,  9.6809998e+02,  9.0199997e+01,
         2.2889206e+00],
       [-5.6999998e+00, -5.8053269e+00,  9.4070001e+02,  9.9199997e+01,
         2.6319027e+00],
       ...,
       [-1.0000000e-01, -1.9647184e-01,  9.7809998e+02,  9.9300003e+01,
         3.8528087e+00],
       [-2.0999999e+00, -3.9984221e+00,  9.0679999e+02,  8.6800003e+01,
         3.1331336e+00],
       [ 2.0000000e+00, -4.2977437e-01,  9.8500000e+02,  8.3900002e+01,
         3.7601204e+00]], dtype=float32)
Coordinates:
  * forecast_reference_time  (s) datetime64[ns] 2017-01-01 ... 2019-12-25T12:...
    t                        (s) int32 ...
  * var                      (var) object 'obs:air_temperature' ... 'obs:wate...
Dimensions without coordinates: s
Full listing of the 'var' coordinate (the 5 targets):
array(['obs:air_temperature', 'obs:dew_point_temperature',
       'obs:surface_air_pressure', 'obs:relative_humidity',
       'obs:water_vapor_mixing_ratio'], dtype=object)
In [25]:
#@title A possible solution for Q2
Nbin = 60 # Number of bins

# Loops over all features
for ix, xvar in enumerate(train_x['var']):
  # Plots the histogram of the feature in black 
  train_x[:,ix].plot.hist(bins=Nbin, alpha=0.5, color='k',
                          label='Forecast to Post-process (Feature)',
                          figsize=(15,2.5))
  
  # If the feature has a corresponding observation,
  # plots its histogram in blue
  iy = [iy for iy in range(len(train_y['var'])) if \
        str(train_y['var'][iy].values)[4:]==str(xvar.values)[4:-7]]
  if iy: train_y[:,iy].plot.hist(bins=Nbin, alpha=0.5, color='b',
                                 label='Observations (Target)')
  
  # Add one title per feature
  plt.title('Feature: '+str(xvar.values)[4:]); 
  plt.legend(fontsize=16)

Q3) Normalize your features¶

You may use a standard normalization setting the mean to 0 and the standard deviation to 1 for all training, validation, and test features.

⚠ You can only use the training set to calculate your features' statistical moments (mean, standard deviation, etc.); otherwise, you are at risk of validation/test data leakage.

Hints:

  1. Make sure you take the mean along the sample (s) dimension, which can be done by specifying the axis (an integer) or the name of the dimension (a string).
  2. The standard normalization relies on the standard score or "Z-score", which you can read more about at this link
  3. The documentation for Xarray's DataArray's mean method is here and that of the standard deviation method is here.
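To make the leakage warning concrete, here is a plain-NumPy sketch (array names and values are ours, chosen for illustration) that computes the z-score statistics on the training split only and reuses them for the test split:

```python
import numpy as np

rng = np.random.default_rng(0)
train = rng.normal(loc=5.0, scale=2.0, size=(1000, 3))  # samples x features
test = rng.normal(loc=5.0, scale=2.0, size=(200, 3))

# Statistics computed on the training set ONLY (axis 0 = sample dimension)
mu = train.mean(axis=0)
sigma = train.std(axis=0)

train_z = (train - mu) / sigma
test_z = (test - mu) / sigma  # reuse the TRAINING statistics on the test split

# The training split is now standardized; the test split is close to,
# but not exactly, zero-mean/unit-variance -- and that is expected
```

Reusing `mu` and `sigma` on the test split is what prevents information about the test distribution from leaking into the model.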
In [29]:
# You can normalize your input data using the standard score
# by completing the code below
train_x_mean = train_x.mean(dim='s')
train_x_std = train_x.std(dim='s')

train_x = (train_x - train_x_mean) / train_x_std
val_x = (val_x - train_x_mean) / train_x_std
test_x = (test_x - train_x_mean) / train_x_std
In [ ]:
#@title A possible solution for Q3
train_x_mean = train_x.mean("s")
train_x_std = train_x.std("s")

train_x = (train_x - train_x_mean) / train_x_std
val_x = (val_x - train_x_mean) / train_x_std
test_x = (test_x - train_x_mean) / train_x_std

Below is code to:

  • Convert the Xarray DataArrays (train_x,train_y) and (val_x,val_y) to Torch Tensors versions of the training and validation data: train and val
  • Build your data generator (or data loader) for the training set train_dl and for the validation set val_dl so that you can seamlessly feed the train and val datasets to the neural networks you train.

You can double click to check out the source code but we will not be focusing on the data generator in this notebook as it's fairly standard.

In [30]:
BATCH_SIZE = 1024 # You can choose your batch size here; we recommend 1024
In [32]:
#@title Source code for Torch data loaders

class Data(Dataset):
    def __init__(self, x, y):

        self.x = torch.tensor(x.values).to("cuda:0")
        self.y = torch.tensor(y.values).to("cuda:0")
        self.station_id = torch.tensor(x.station_id.values, dtype=torch.int32).to("cuda:0")
        self.y_coords = y.coords

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return (self.x[idx], self.station_id[idx]), self.y[idx]


class DataLoader(object):
    def __init__(self, dataset, batch_size, shuffle=False):
        self.dataset = dataset
        self.dataset_len = self.dataset.x.shape[0]
        self.batch_size = batch_size
        self.shuffle = shuffle
        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
        if remainder > 0:
            n_batches += 1
        self.n_batches = n_batches

    def __iter__(self):
        if self.shuffle:
            r = torch.randperm(self.dataset_len)
            self.dataset.x = self.dataset.x[r]
            self.dataset.station_id = self.dataset.station_id[r]
            self.dataset.y = self.dataset.y[r]
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.dataset_len:
            raise StopIteration
        batch = self.dataset[self.i : self.i + self.batch_size]
        self.i += self.batch_size
        return batch

    def __len__(self):
        return self.n_batches
In [33]:
# Convert Xarray training/validation DataArray to Torch Tensors
train = Data(train_x, train_y)
val = Data(val_x, val_y)

# Build training/validation data loaders
train_dl = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val, batch_size=BATCH_SIZE * 8)

Part III. Physically-constrained neural network¶

Now comes the fun part! We are going to compare two methods to incorporate thermodynamic state equations (here, the ideal gas law and an integral of Clausius-Clapeyron) into neural networks post-processing surface weather 🌧

  1. A loss-constrained setting (d), in which the neural network is penalized for violating the laws of thermodynamics (a soft constraint), and
  2. An architecture-constrained setting (c), in which the neural network is forced to follow the laws of thermodynamics (a hard constraint).

We also use an unconstrained setting (b) as a reference baseline.


Comparison of different ways of enforcing physical constraints in a neural network post-processing temperature (T), dew-point temperature ($T_d$), pressure (P), relative humidity (RH), and water vapor mixing ratio (r).

Figure 2 of Zanetta et al. (2022)

Q4) Using the schematic above and the equations below, implement the physical constraints¶

This layer should take in temperature, temperature deficit, and pressure, and output temperature, dew-point temperature, pressure, relative humidity, and water vapor mixing ratio. Here, we'll practice directly on the observations, which already follow the physical constraints.

1) Dew-point temperature $T_{d}$ is related to the temperature T and the temperature deficit $T_{def}$ through its definition:

$T_{d} = T - T_{def}$

2) Water vapor mixing ratio r can be calculated as a function of pressure and dew-point temperature through the water vapor partial pressure, which can be approximated by integrating the Clausius-Clapeyron equation. Using the approximate integral used by MeteoSwiss, the water vapor partial pressure is given by:

$e \left[ hPa \right] = c \times \exp \left( \frac {a T_{d}}{b+T_{d}} \right) $,

where $a\approx 17.368$, $b\approx238.83$, and $c\approx6.107$hPa for $T \geq 0$ and
$a\approx 17.856$, $b\approx245.52$, and $c\approx6.108$hPa for $T < 0$.

Then, using the ideal gas law, we can calculate the water vapor mixing ratio [in g/kg] from the water vapor partial pressure $e$ and the total air pressure $p$ :

$r \left[g/kg\right] = 1000 \times 0.622 \times \frac{e}{p-e}$

3) Relative humidity RH is defined as the ratio of the water vapor partial pressure $e$ to its saturation value $e_{s}$:

$RH \left[\%\right] = 100 \times e/e_{s}$.

Note that saturation water vapor pressure can be calculated by substituting the dew-point temperature for its absolute value in the formula for water vapor partial pressure:

$e_{s} \left[ hPa \right] = c \times \exp \left( \frac {a T}{b+T} \right) $
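As a quick sanity check of the equations above, here is a NumPy sketch that walks one sample state through them (the values T = 20 °C, T_def = 5 K, p = 950 hPa are ours, chosen for illustration):

```python
import numpy as np

# Worked example: T = 20 degC, T_def = 5 K, p = 950 hPa (warm branch, T >= 0)
t, t_def, p = 20.0, 5.0, 950.0

t_d = t - t_def                          # 1) dew-point temperature [degC]

a, b, c = 17.368, 238.83, 6.107          # warm-branch constants
e = c * np.exp(a * t_d / (b + t_d))      # water vapor partial pressure [hPa]
e_s = c * np.exp(a * t / (b + t))        # saturation value at T [hPa]

r = 1000.0 * 0.622 * e / (p - e)         # 2) water vapor mixing ratio [g/kg]
rh = 100.0 * e / e_s                     # 3) relative humidity [%]
```

The resulting relative humidity sits between 0 and 100% and the mixing ratio is positive, as physics demands for this (sub-saturated) state.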

In [34]:
# Assume we already have access to temperature, temperature deficit,
# and surface air pressure

temperature = train_y[:,0]
temperature_deficit = temperature - train_y[:,1] # Derive the deficit from the
                                                 # observed dew point so we can
                                                 # practice recovering it below
air_pressure = train_y[:,2]
In [41]:
# Can you calculate water vapor mixing ratio and relative humidity? 
# You may fill out the code below
# and use numpy's where function 
# (https://numpy.org/doc/stable/reference/generated/numpy.where.html),
# which returns elements from two alternatives depending on a condition. 
# np.where is useful for the above version of the water vapor partial pressure, 
# which is defined piecewise based on the sign of the temperature. 

# 1) Calculate the dewpoint temperature
temperature_dp = temperature - temperature_deficit

# 2) Calculate water vapor mixing ratio
wv_pressure = np.where(
    temperature >= 0.0,
    6.107 * np.exp((17.368*temperature_dp) / (238.83+temperature_dp)),
    6.108 * np.exp((17.856*temperature_dp) / (245.52+temperature_dp)),
    )
wv_mix_ratio = 1000*0.622 * (wv_pressure / (air_pressure-wv_pressure))

# 3) Relative humidity
wv_satpressure = np.where(
    temperature >= 0.0,
    6.107 * np.exp((17.368*temperature) / (238.83+temperature)),
    6.108 * np.exp((17.856*temperature) / (245.52+temperature)),
    )
relative_humidity = 100*wv_pressure/wv_satpressure
In [ ]:
# Or use this empty cell
In [42]:
#@title This cell should output a very small number (absolute value less than 1e-6) if you correctly implemented the physical constraints
print(f'1) The mean residual of the dew-point temperature is {np.mean((temperature_dp-train_y[:,1]).values):.1e} K')
print(f'2) The mean residual of the water vapor mixing ratio is {np.mean((wv_mix_ratio-train_y[:,4]).values):.1e} g/kg')
print(f'3) The mean residual of the relative humidity is {np.mean((relative_humidity-train_y[:,3]).values):.1e} %')
1) The mean residual of the dew-point temperature is -1.2e-10 K
2) The mean residual of the water vapor mixing ratio is 1.7e-07 g/kg
3) The mean residual of the relative humidity is -3.4e-07 %
In [43]:
#@title A possible solution for Q4

# 1) Calculate the dewpoint temperature
temperature_dp = temperature - temperature_deficit

# 2) Calculate water vapor mixing ratio
wv_pressure = np.where(
    temperature >= 0.0,
    6.107 * np.exp((17.368 * temperature_dp) / (temperature_dp + 238.83)),
    6.108 * np.exp((17.856 * temperature_dp) / (temperature_dp + 245.52)),
    )
wv_mix_ratio = 622.0 * (wv_pressure / (air_pressure - wv_pressure))

# 3) Relative humidity
wv_satpressure = np.where(
    temperature >= 0.0,
    6.107 * np.exp((17.368 * temperature) / (temperature + 238.83)),
    6.108 * np.exp((17.856 * temperature) / (temperature + 245.52)),
    )
relative_humidity = 100.0 * wv_pressure / wv_satpressure

Converting our physical constraints to a physically-constrained layer requires two steps:

  1. Writing a custom PhysicsLayer class that inherits from Torch's base module class (nn.Module), and
  2. Converting our physical constraints from standard Python/numpy to PyTorch, remembering that temperature, temperature deficit, and surface air pressure are direct outputs of the neural network and inputs of our physically-constrained layer.

This typically takes more than the 30 minutes we have for this notebook, so we directly give you the source code for the PhysicsLayer below.

In [44]:
#@title Source code for the physically-constrained layer: Run it to proceed, double click to see the code
class PhysicsLayer(nn.Module):
    def __init__(self):
        super(PhysicsLayer, self).__init__()

    def forward(self, direct):
        t, t_def, p = direct[:, 0], direct[:, 1], direct[:, 2]
        # Below, the rectified linear unit ensures the positivity of
        # the deficit temperature, which is not guaranteed when it is
        # a direct output of the neural network 
        t_d = t - torch.relu(t_def)
        e_s = torch.where(
            t >= 0.0,
            6.107 * torch.exp((17.368 * t) / (t + 238.83)),
            6.108 * torch.exp((17.856 * t) / (t + 245.52)),
        )
        e = torch.where(
            t >= 0.0,
            6.107 * torch.exp((17.368 * t_d) / (t_d + 238.83)),
            6.108 * torch.exp((17.856 * t_d) / (t_d + 245.52)),
        )
        rh = e / e_s * 100.0
        r = 622.0 * (e / (p - e))
        pred = torch.stack([t, t_d, p, rh, r], dim=1)
        return pred

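The comment about the rectified linear unit is worth unpacking: clipping the deficit at zero guarantees $T_d \leq T$, which in turn keeps the derived relative humidity at or below roughly 100%. Here is a standalone NumPy sketch of the same idea (sample values are ours):

```python
import numpy as np

def vapor_pressure(t_like):
    # Magnus-type approximation used above, branch chosen by sign
    return np.where(t_like >= 0.0,
                    6.107 * np.exp(17.368 * t_like / (t_like + 238.83)),
                    6.108 * np.exp(17.856 * t_like / (t_like + 245.52)))

t = np.array([0.1, 12.0, 25.0])
raw_t_def = np.array([-1.5, 0.5, 3.0])  # a raw network output could be negative

t_d = t - np.maximum(raw_t_def, 0.0)    # ReLU clips the deficit at zero
rh = 100.0 * vapor_pressure(t_d) / vapor_pressure(t)

# t_d can never exceed t, so the derived relative humidity stays bounded
```

Without the clipping, a negative predicted deficit would imply a dew point above the temperature, i.e. a relative humidity above 100%.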
Additionally, below is the source code for all of our neural networks' architectures, which are wrapped in the Net class. In this class, we initialize the output bias manually to values that are in the same order of magnitude as our targets to facilitate training. Optionally, if the constraint argument is set to True, the architecture may include the PhysicsLayer encoding the physical constraints.

In [45]:
#@title Source code for the neural networks' architectures' wrapper class: Run it to proceed, double click to see the code
class Net(nn.Module):

    # Initialize the biases with mean value for each variable to facilitate learning
    out_bias = [15.0, 10.0, 900.0, 70.0, 5.0]  # t, t_d, p, rh, r
    out_bias_constrained = [15.0, 5.0, 900]  # t, t_def, p

    # Initializes the neural network architecture
    def __init__(self, in_size, n_stations, embedding_size, l1, l2, constraint=False):
        super(Net, self).__init__()
        self.embedding = nn.Embedding(n_stations, embedding_size)
        self.l1 = nn.Linear(in_size + embedding_size, l1)
        self.l2 = nn.Linear(l1, l2)
        # Physically-constrained version
        if constraint:
            self.out = nn.Sequential(nn.Linear(l2, 3), PhysicsLayer())
            self.out[0].bias = nn.Parameter(torch.Tensor(self.out_bias_constrained))
        # Unconstrained version
        else:
            self.out = nn.Linear(l2, 5)
            self.out.bias = nn.Parameter(torch.Tensor(self.out_bias))

    # Defines a forward pass through the neural network
    def forward(self, x, station_id):
        station_embedding = self.embedding(station_id)
        x = torch.concat([x, station_embedding], dim=-1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        out = self.out(x)
        return out

Part IV. Multi-task loss function¶

Even in the unconstrained case, our loss function is multi-task because it optimizes the post-processing of multiple variables (mixing ratio, relative humidity, etc.). Choosing how to weigh different terms of the loss function when a regression has multiple outputs with different units and uncertainties is tricky and time-consuming 😖

Therefore, we follow the multi-task framework of:

Kendall, A., Gal, Y., & Cipolla, R. (2018). Multi-task learning using uncertainty to weigh losses for scene geometry and semantics. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 7482-7491).

and weigh each term of the loss function by a learned uncertainty term, which is not allowed to set the term to zero (even if it's hard to learn!). We refer you to the paper above for the mathematical details, summarized in section 2c of Zanetta et al.'s preprint.

Below, we define the MultiTaskLoss class to implement this uncertainty-weighted loss function and, optionally, add the physics-based loss term computed with physical_penalty. This term balances performance and physical consistency using a hyperparameter $\alpha$:

$\mathrm{loss} = (1 - \alpha) \times (\text{multi-task loss}) + \alpha \times (\text{physics loss})$

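The uncertainty weighting itself can be sketched in a few lines of NumPy (a simplified toy illustration, not the notebook's Torch implementation): each per-task loss $L_i$ is scaled by $\exp(-s_i)$, and the log-variance $s_i$ is added back, so a task can be down-weighted but never zeroed out for free:

```python
import numpy as np

task_losses = np.array([4.0, 0.25, 9.0])   # toy MSEs with very different scales
log_var = np.zeros(3)                      # learnable s_i, initialized at 0

def multitask_loss(task_losses, log_var):
    # exp(-s) * L + s: a large s shrinks a noisy task's weight,
    # but the "+ s" term penalizes inflating s indefinitely
    return np.sum(np.exp(-log_var) * task_losses + log_var)

base = multitask_loss(task_losses, log_var)               # equals task_losses.sum()
tuned = multitask_loss(task_losses, np.log(task_losses))  # per-task optimum s_i = log(L_i)
```

Setting each $s_i$ to $\log L_i$ minimizes that term (each contributes $1 + \log L_i$), which is how training automatically rebalances tasks with very different magnitudes.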
In [46]:
#@title Source code for the multitask loss: Run it to proceed, double click to see the code
class MultiTaskLoss(nn.Module):
    def __init__(self, alpha=0.0, mask=None, log_var_init=None):
        super(MultiTaskLoss, self).__init__()
        self.alpha = alpha
        self.mask = [True] * 5 if mask is None else mask

        log_var = torch.zeros(5) if log_var_init is None else log_var_init

        self.log_var = nn.Parameter(log_var)

        self.hp = {"alpha": alpha}

    def forward(self, pred, y):
        loss = torch.mean((pred - y) ** 2, axis=0)
        loss = torch.exp(-self.log_var) * loss + self.log_var
        loss = torch.sum(loss[self.mask])
        rh_res, r_res = physical_penalty(pred)
        physics_loss = rh_res / torch.var(y[:, 3]) + r_res / torch.var(y[:, 4])
        if self.alpha > 0.0:
            loss = (1 - self.alpha) * loss + self.alpha * physics_loss
        return loss, physics_loss


def physical_penalty(pred):
    t, t_d, p, rh, r = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3], pred[:, 4]
    e = torch.where(
        t >= 0.0,
        6.107 * torch.exp((17.368 * t_d) / (t_d + 238.83)),
        6.108 * torch.exp((17.856 * t_d) / (t_d + 245.52)),
    )
    e_s = torch.where(
        t >= 0.0,
        6.107 * torch.exp((17.368 * t) / (t + 238.83)),
        6.108 * torch.exp((17.856 * t) / (t + 245.52)),
    )
    rh_derived = e / (e_s + 1e-5) * 100.0
    r_derived = 622.0 * (e / (p - e))

    return (
        torch.mean((rh_derived - rh) ** 2),
        torch.mean((r_derived - r) ** 2),
    )

Part V. Model training¶

In [47]:
LR = 0.0008 # Define your learning rate here (recommended = 8e-4)
In [48]:
#@title Source code for the training/validation functions: Run it to proceed, double click to see the code

def training_step(model, loss_fn, optimizer, train_dataloader):
    model.train(True)
    loss_fn.train(True)
    running_loss = 0.0
    num_batches = len(train_dataloader)
    iterator = enumerate(train_dataloader)
    for i, (X, y) in iterator:
        pred = model(*X)
        loss, p = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / num_batches


def validation_step(model, loss_fn, val_dataloader):
    model.train(False), loss_fn.train(False)
    val_loss = 0.0
    val_p = 0.0
    val_mae = 0.0
    with torch.no_grad():
        iterator = val_dataloader
        for X, y in iterator:
            pred = model(*X)
            loss, p = loss_fn(pred, y)
            val_loss += loss.item()
            val_p += p.item()
            val_mae += torch.mean(torch.abs(pred - y), dim=0)
    val_loss /= len(val_dataloader)
    val_p /= len(val_dataloader)
    val_mae /= len(val_dataloader)
    return val_loss, val_p, val_mae


def fit(model, loss_fn, train_dl, val_dl, max_epochs=20, max_patience=5):
    optimizer = optim.Adam(
        chain(model.parameters(), loss_fn.parameters()), lr=LR
    )
    best_val_nmae = torch.inf
    val_y_std = val.y.std(dim=0)
    for epoch in range(max_epochs):

        train_loss = training_step(model, loss_fn, optimizer, train_dl)
        val_loss, val_p, val_mae = validation_step(model, loss_fn, val_dl)
        val_nmae = torch.mean((val_mae / val_y_std))

        if val_nmae < best_val_nmae:
            best_val_nmae = val_nmae
            best_model = model
            patience = 0
        else:
            patience += 1
            if patience == max_patience:
                break

        print(
            f"{epoch+1:<2}{'':>4}loss: {train_loss:<10.4}val_loss: {val_loss:<10.4}"
            f"val_t_mae: {val_mae[0]:<10.4} val_rh_mae: {val_mae[3]:<10.4}"
            f"val_nmae: {val_nmae:<10.4} val_p: {val_p:<10.4}"
        )

    return best_model
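The early-stopping logic inside fit can be isolated in a short sketch (the validation scores below are toy numbers of ours): the patience counter resets on every improvement and training stops once it reaches max_patience:

```python
# Toy sequence of validation scores (lower is better)
val_scores = [0.30, 0.25, 0.26, 0.24, 0.27, 0.28, 0.29, 0.31, 0.30]
max_patience = 3

best, patience, stopped_at = float("inf"), 0, None
for epoch, score in enumerate(val_scores):
    if score < best:
        best, patience = score, 0   # improvement: keep model, reset counter
    else:
        patience += 1               # no improvement this epoch
        if patience == max_patience:
            stopped_at = epoch      # stop after max_patience bad epochs
            break
```

Here the best score (0.24) is reached at epoch 3, and the loop stops three non-improving epochs later, mirroring how fit returns the best model rather than the last one.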
In [49]:
torch.device("cuda:0") # Enables GPU runtime in Google Colab
Out[49]:
device(type='cuda', index=0)

Q5) Train three neural networks: unconstrained, architecture-constrained, and loss-constrained.¶

Note: The training may take a very long time and can easily exceed 30 minutes per network if you use more than 15 epochs. Therefore, we provide trained models after Q5's solution.

Hints:

  1. You can instantiate a neural network using the Net class we defined above, which takes the following arguments:
  • in_size: The input size (here we have 11 features from the NWP model)
  • n_stations: The number of stations (here 131)
  • embedding_size: The embedding size, which you can read more about here
  • l1: The number of units in layer 1 (and NOT the L1 regularization coefficient)
  • l2: The number of units in layer 2 (and NOT the L2 regularization coefficient)
  • constraint: A boolean defining whether a physics-constraint layer is used or not
  2. After defining the loss, you may use the fit function defined above.

fit takes the following arguments:

  • model: The machine learning model to train
  • loss_fn: The loss function
  • train_dl: The training data loader
  • val_dl: The validation data loader
  • max_epochs: The maximal number of epochs
  • max_patience: The maximal patience (how many epochs we wait until stopping the training if the validation loss does not improve)

Note that as you fit your neural network, the function will provide the following diagnostics:

  • loss: The value of the loss function you chose, averaged over the training set
  • val_loss: The value of the loss function you chose, averaged over the validation set
  • val_t_mae: The mean absolute error of the temperature (in K) over the validation set
  • val_rh_mae: The mean absolute error of the relative humidity (in %) over the validation set
  • val_nmae: The normalized mean absolute error (using the standard deviation of each variable to normalize its MAE) over the validation set
  • val_p: The physical violation term, i.e., the residual from the thermodynamic state equations, averaged over the validation set.
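As a small illustration of how val_nmae is assembled (the MAE and standard-deviation values below are toy numbers of ours, not the notebook's results):

```python
import numpy as np

# Per-variable MAE (K, K, hPa, %, g/kg) and per-variable spread
val_mae = np.array([1.5, 1.7, 1.0, 9.3, 0.6])
val_std = np.array([8.0, 8.5, 60.0, 20.0, 3.0])

# Normalizing each MAE by its variable's standard deviation makes the five
# incommensurable errors dimensionless and comparable before averaging
val_nmae = np.mean(val_mae / val_std)
```

This is why a single scalar can summarize skill across temperature, pressure, and humidity despite their very different units.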
In [50]:
# Define your hyperparameters here
# We give the recommended value for each hyperparameter below

ALPHA = 0.995 # Recommended for the loss-constrained neural networks
# The hyperparameters below are recommended across models
EMBEDDING_SIZE=5
MAX_EPOCHS = 20 # 10-15 to finish notebook on time but at least 20 for performance
N_INPUTS = 11
N_STATIONS = 131
PATIENCE = 5
UNITS_L1 = 128
UNITS_L2 = 256
In [ ]:
# Uncomment below to use multiple threads
# torch.set_num_threads(2)
In [53]:
# Unconstrained neural network: Definition
# Net(in_size, n_stations, embedding_size, l1, l2, constraint)
unconstrained_NN = Net(
    N_INPUTS, N_STATIONS, EMBEDDING_SIZE, UNITS_L1, UNITS_L2, constraint=False
    ).to("cuda:0")
loss = MultiTaskLoss(alpha=0).to("cuda:0")
In [54]:
# Unconstrained neural network: Training
unconstrained_NN = fit(unconstrained_NN, loss, train_dl, val_dl, 
                       max_epochs=MAX_EPOCHS,
                       max_patience=PATIENCE)
1     loss: 124.1     val_loss: 14.19     val_t_mae: 1.562      val_rh_mae: 9.747     val_nmae: 0.2201     val_p: 0.04771   
2     loss: 13.7      val_loss: 13.49     val_t_mae: 1.511      val_rh_mae: 9.602     val_nmae: 0.216      val_p: 0.04005   
3     loss: 13.5      val_loss: 13.41     val_t_mae: 1.49       val_rh_mae: 9.458     val_nmae: 0.2138     val_p: 0.03094   
4     loss: 13.43     val_loss: 13.36     val_t_mae: 1.488      val_rh_mae: 9.404     val_nmae: 0.2126     val_p: 0.02828   
5     loss: 13.38     val_loss: 13.38     val_t_mae: 1.483      val_rh_mae: 9.438     val_nmae: 0.213      val_p: 0.03341   
6     loss: 13.34     val_loss: 13.38     val_t_mae: 1.469      val_rh_mae: 9.304     val_nmae: 0.2127     val_p: 0.06609   
7     loss: 13.31     val_loss: 13.32     val_t_mae: 1.468      val_rh_mae: 9.289     val_nmae: 0.211      val_p: 0.04206   
8     loss: 13.29     val_loss: 13.37     val_t_mae: 1.472      val_rh_mae: 9.281     val_nmae: 0.2107     val_p: 0.02972   
9     loss: 13.27     val_loss: 13.31     val_t_mae: 1.47       val_rh_mae: 9.281     val_nmae: 0.2107     val_p: 0.03096   
10    loss: 13.25     val_loss: 13.35     val_t_mae: 1.469      val_rh_mae: 9.325     val_nmae: 0.2118     val_p: 0.03391   
11    loss: 13.24     val_loss: 13.31     val_t_mae: 1.475      val_rh_mae: 9.323     val_nmae: 0.2115     val_p: 0.0292    
12    loss: 13.23     val_loss: 13.33     val_t_mae: 1.484      val_rh_mae: 9.214     val_nmae: 0.2106     val_p: 0.03745   
13    loss: 13.21     val_loss: 13.31     val_t_mae: 1.464      val_rh_mae: 9.222     val_nmae: 0.21       val_p: 0.02436   
14    loss: 13.2      val_loss: 13.32     val_t_mae: 1.47       val_rh_mae: 9.184     val_nmae: 0.2099     val_p: 0.0226    
15    loss: 13.19     val_loss: 13.29     val_t_mae: 1.466      val_rh_mae: 9.25      val_nmae: 0.2108     val_p: 0.03461   
16    loss: 13.18     val_loss: 13.28     val_t_mae: 1.454      val_rh_mae: 9.166     val_nmae: 0.2094     val_p: 0.02399   
17    loss: 13.17     val_loss: 13.32     val_t_mae: 1.459      val_rh_mae: 9.23      val_nmae: 0.2108     val_p: 0.03517   
18    loss: 13.16     val_loss: 13.33     val_t_mae: 1.468      val_rh_mae: 9.195     val_nmae: 0.2098     val_p: 0.02806   
19    loss: 13.15     val_loss: 13.28     val_t_mae: 1.451      val_rh_mae: 9.17      val_nmae: 0.2091     val_p: 0.02733   
20    loss: 13.15     val_loss: 13.36     val_t_mae: 1.46       val_rh_mae: 9.244     val_nmae: 0.2122     val_p: 0.03072   
In [55]:
# Architecture-constrained neural network: Definition
architecture_constrained_NN = Net(
    N_INPUTS, N_STATIONS, EMBEDDING_SIZE, UNITS_L1, UNITS_L2, constraint=True
    ).to("cuda:0")
loss = MultiTaskLoss(alpha=0).to("cuda:0")
In [56]:
# Architecture-constrained neural network: Training
architecture_constrained_NN = fit(architecture_constrained_NN, loss, 
                                  train_dl, val_dl,
                                  max_epochs=MAX_EPOCHS,
                                  max_patience=PATIENCE)
1     loss: 115.5     val_loss: 14.11     val_t_mae: 1.539      val_rh_mae: 9.664     val_nmae: 0.2185     val_p: 2.986e-11 
2     loss: 13.6      val_loss: 13.47     val_t_mae: 1.505      val_rh_mae: 9.463     val_nmae: 0.2138     val_p: 3.02e-11  
3     loss: 13.45     val_loss: 13.46     val_t_mae: 1.506      val_rh_mae: 9.325     val_nmae: 0.2121     val_p: 3.464e-11 
4     loss: 13.38     val_loss: 13.4      val_t_mae: 1.484      val_rh_mae: 9.273     val_nmae: 0.2109     val_p: 3.117e-11 
5     loss: 13.34     val_loss: 13.37     val_t_mae: 1.496      val_rh_mae: 9.281     val_nmae: 0.2113     val_p: 3.324e-11 
6     loss: 13.31     val_loss: 13.36     val_t_mae: 1.479      val_rh_mae: 9.45      val_nmae: 0.2129     val_p: 2.995e-11 
7     loss: 13.29     val_loss: 13.38     val_t_mae: 1.463      val_rh_mae: 9.244     val_nmae: 0.2102     val_p: 3.17e-11  
8     loss: 13.27     val_loss: 13.35     val_t_mae: 1.477      val_rh_mae: 9.251     val_nmae: 0.2103     val_p: 3.232e-11 
9     loss: 13.25     val_loss: 13.32     val_t_mae: 1.466      val_rh_mae: 9.306     val_nmae: 0.2101     val_p: 2.974e-11 
10    loss: 13.23     val_loss: 13.3      val_t_mae: 1.455      val_rh_mae: 9.174     val_nmae: 0.2091     val_p: 3.12e-11  
11    loss: 13.22     val_loss: 13.3      val_t_mae: 1.459      val_rh_mae: 9.205     val_nmae: 0.2092     val_p: 3.195e-11 
12    loss: 13.21     val_loss: 13.33     val_t_mae: 1.467      val_rh_mae: 9.328     val_nmae: 0.2116     val_p: 3.074e-11 
13    loss: 13.2      val_loss: 13.32     val_t_mae: 1.458      val_rh_mae: 9.38      val_nmae: 0.211      val_p: 2.937e-11 
14    loss: 13.19     val_loss: 13.37     val_t_mae: 1.472      val_rh_mae: 9.271     val_nmae: 0.2111     val_p: 3.2e-11   
15    loss: 13.18     val_loss: 13.27     val_t_mae: 1.448      val_rh_mae: 9.209     val_nmae: 0.2088     val_p: 3.129e-11 
16    loss: 13.17     val_loss: 13.33     val_t_mae: 1.462      val_rh_mae: 9.432     val_nmae: 0.2122     val_p: 2.947e-11 
17    loss: 13.16     val_loss: 13.31     val_t_mae: 1.457      val_rh_mae: 9.155     val_nmae: 0.2087     val_p: 3.22e-11  
18    loss: 13.15     val_loss: 13.28     val_t_mae: 1.452      val_rh_mae: 9.174     val_nmae: 0.2084     val_p: 3.201e-11 
19    loss: 13.15     val_loss: 13.32     val_t_mae: 1.468      val_rh_mae: 9.297     val_nmae: 0.2116     val_p: 3.118e-11 
20    loss: 13.14     val_loss: 13.26     val_t_mae: 1.453      val_rh_mae: 9.149     val_nmae: 0.2081     val_p: 3.071e-11 
In [59]:
# Loss-constrained neural network: Definition
loss_constrained_NN = Net(
    N_INPUTS, N_STATIONS, EMBEDDING_SIZE, UNITS_L1, UNITS_L2, constraint=False
    ).to("cuda:0")
constrained_loss = MultiTaskLoss(alpha=ALPHA).to("cuda:0")
In [60]:
# Loss-constrained neural network: Training
loss_constrained_NN = fit(loss_constrained_NN, constrained_loss,
                          train_dl, val_dl,
                          max_epochs=MAX_EPOCHS,
                          max_patience=PATIENCE)
1     loss: 0.6702    val_loss: 0.07238   val_t_mae: 1.536      val_rh_mae: 9.706     val_nmae: 0.2187     val_p: 0.001698  
2     loss: 0.06928   val_loss: 0.06866   val_t_mae: 1.501      val_rh_mae: 9.627     val_nmae: 0.216      val_p: 0.00115   
3     loss: 0.06831   val_loss: 0.06896   val_t_mae: 1.494      val_rh_mae: 9.531     val_nmae: 0.2148     val_p: 0.001639  
4     loss: 0.06795   val_loss: 0.06832   val_t_mae: 1.48       val_rh_mae: 9.475     val_nmae: 0.2133     val_p: 0.001244  
5     loss: 0.06772   val_loss: 0.0681    val_t_mae: 1.478      val_rh_mae: 9.432     val_nmae: 0.2129     val_p: 0.001024  
6     loss: 0.06755   val_loss: 0.06823   val_t_mae: 1.472      val_rh_mae: 9.404     val_nmae: 0.2124     val_p: 0.001358  
7     loss: 0.06741   val_loss: 0.06814   val_t_mae: 1.469      val_rh_mae: 9.359     val_nmae: 0.2117     val_p: 0.001327  
8     loss: 0.06728   val_loss: 0.06775   val_t_mae: 1.457      val_rh_mae: 9.347     val_nmae: 0.211      val_p: 0.001085  
9     loss: 0.06718   val_loss: 0.06781   val_t_mae: 1.458      val_rh_mae: 9.323     val_nmae: 0.2109     val_p: 0.0008657 
10    loss: 0.06708   val_loss: 0.06772   val_t_mae: 1.46       val_rh_mae: 9.334     val_nmae: 0.2108     val_p: 0.001047  
11    loss: 0.06699   val_loss: 0.06813   val_t_mae: 1.466      val_rh_mae: 9.331     val_nmae: 0.211      val_p: 0.001563  
12    loss: 0.06691   val_loss: 0.06758   val_t_mae: 1.458      val_rh_mae: 9.308     val_nmae: 0.2104     val_p: 0.0009573 
13    loss: 0.06682   val_loss: 0.06757   val_t_mae: 1.463      val_rh_mae: 9.324     val_nmae: 0.2108     val_p: 0.0008462 
14    loss: 0.06675   val_loss: 0.06917   val_t_mae: 1.451      val_rh_mae: 9.246     val_nmae: 0.2094     val_p: 0.002636  
15    loss: 0.06669   val_loss: 0.06748   val_t_mae: 1.454      val_rh_mae: 9.254     val_nmae: 0.2097     val_p: 0.001036  
16    loss: 0.06663   val_loss: 0.06763   val_t_mae: 1.46       val_rh_mae: 9.267     val_nmae: 0.2096     val_p: 0.001261  
17    loss: 0.06658   val_loss: 0.06753   val_t_mae: 1.462      val_rh_mae: 9.212     val_nmae: 0.209      val_p: 0.001218  
18    loss: 0.06653   val_loss: 0.06873   val_t_mae: 1.451      val_rh_mae: 9.27      val_nmae: 0.2096     val_p: 0.002349  
19    loss: 0.06648   val_loss: 0.0678    val_t_mae: 1.461      val_rh_mae: 9.225     val_nmae: 0.2092     val_p: 0.001326  
20    loss: 0.06644   val_loss: 0.06859   val_t_mae: 1.449      val_rh_mae: 9.243     val_nmae: 0.2091     val_p: 0.002241  
In [ ]:
#@title A possible solution for Q5 (1/3)

torch.set_num_threads(2)

# Unconstrained neural network
unconstrained_NN = Net(
    N_INPUTS, N_STATIONS, EMBEDDING_SIZE, UNITS_L1, UNITS_L2, constraint=False
    ).to("cuda:0")
loss = MultiTaskLoss(alpha=0.0).to("cuda:0")

unconstrained_NN = fit(unconstrained_NN, loss, train_dl, val_dl, 
                       max_epochs=MAX_EPOCHS,
                       max_patience=PATIENCE)
In [ ]:
#@title A possible solution for Q5 (2/3)

# Architecture-constrained neural network
architecture_constrained_NN = Net(
    N_INPUTS, N_STATIONS, EMBEDDING_SIZE, UNITS_L1, UNITS_L2, constraint=True
    ).to("cuda:0")
loss = MultiTaskLoss(alpha=0.0).to("cuda:0")

architecture_constrained_NN = fit(architecture_constrained_NN, loss, 
                                  train_dl, val_dl,
                                  max_epochs=MAX_EPOCHS,
                                  max_patience=PATIENCE)
In [ ]:
#@title A possible solution for Q5 (3/3)

# Loss-constrained neural network
loss_constrained_NN = Net(
    N_INPUTS, N_STATIONS, EMBEDDING_SIZE, UNITS_L1, UNITS_L2, constraint=False
    ).to("cuda:0")
constrained_loss = MultiTaskLoss(alpha=ALPHA).to("cuda:0")

loss_constrained_NN = fit(loss_constrained_NN, constrained_loss,
                          train_dl, val_dl,
                          max_epochs=MAX_EPOCHS,
                          max_patience=PATIENCE)
In [ ]:
#@title If you are running out of time, we provide the weights and biases of trained models, which you may directly load by running the cell below

# Unconstrained neural network. Using a different name to prevent overwriting.
unconstrained_NN_provided = Net(
    N_INPUTS, N_STATIONS, EMBEDDING_SIZE, UNITS_L1, UNITS_L2, constraint=False
    ).to("cuda:0")
unconstrained_NN_provided.load_state_dict(torch.load(u_open))

# For more custom models, load the entire model instead of the state_dict
# Architecture-constrained neural network. Using a different name to prevent overwriting.
architecture_constrained_NN_provided = torch.load(a_open)

# Loss-constrained neural network. Using a different name to prevent overwriting.
loss_constrained_NN_provided = Net(
    N_INPUTS, N_STATIONS, EMBEDDING_SIZE, UNITS_L1, UNITS_L2, constraint=False
    ).to("cuda:0")
loss_constrained_NN_provided.load_state_dict(torch.load(l_open))

Note: The models that are already trained use different names to avoid overwriting the models you trained:

  • unconstrained_NN_provided for the unconstrained neural network
  • architecture_constrained_NN_provided for the architecture-constrained neural network
  • loss_constrained_NN_provided for the loss-constrained neural network
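The cell above illustrates two loading styles: restoring a `state_dict` into a freshly built model versus loading a whole pickled model object. A minimal, self-contained sketch of the recommended `state_dict` round trip, using a tiny `nn.Linear` as a hypothetical stand-in for `Net`:

```python
import torch
import torch.nn as nn

# Hypothetical tiny model standing in for Net
model = nn.Linear(4, 2)

# Save only the learned parameters (recommended for standard architectures)
torch.save(model.state_dict(), "model_state.pt")

# Rebuild the same architecture, then restore the parameters into it
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load("model_state.pt"))

# Switch to evaluation mode before inference
restored.eval()
```

Saving the `state_dict` keeps the checkpoint decoupled from the class definition; pickling the whole model (as done for the architecture-constrained network above) is convenient for custom models but breaks if the class code changes.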

Part VI. Model evaluation¶

In [61]:
# Convert the test dataset to Torch
test = Data(test_x, test_y)

# Build a test data loader/generator
test_dl = DataLoader(test, batch_size=BATCH_SIZE * 16)
In [62]:
#@title Source code for the model evaluation function
def evaluate(model, test_dl):
    # Switch to evaluation mode (disables dropout, batch-norm updates, etc.)
    model.eval()
    mae = 0.0
    # Disable gradient tracking for faster, memory-light inference
    with torch.no_grad():
        for X, y in test_dl:
            pred = model(*X)
            # Per-target mean absolute error, accumulated over batches
            mae += torch.mean(torch.abs(pred - y), dim=0)
    # Average the per-batch MAEs over all batches
    return mae / len(test_dl)

Q6) Evaluate the performance and physical consistency of each neural network type (unconstrained, architecture-constrained, and loss-constrained).¶

What are the benefits and drawbacks of each architecture/loss?

Hint: You may call the evaluate function, which takes two arguments:

  1. The machine learning model model
  2. The data loader of the dataset you would like to use for evaluation.

It outputs the mean absolute error for each target variable: air temperature, dew point temperature, surface air pressure, relative humidity, and water vapor mixing ratio.
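Since `evaluate` returns a bare tensor, it can help to label each entry with its target variable. A small sketch (the names and ordering are assumed from the description above, and the MAE values are copied from the unconstrained-network output below):

```python
import torch

# Target variables in the order evaluate() reports them (assumed ordering)
TARGET_NAMES = ["air temperature", "dew point temperature",
                "surface air pressure", "relative humidity",
                "water vapor mixing ratio"]

# Example MAE tensor, as returned by evaluate(unconstrained_NN, test_dl)
mae = torch.tensor([1.5021, 1.6904, 1.0258, 9.3728, 0.6128])

for name, value in zip(TARGET_NAMES, mae):
    print(f"{name}: {value:.4f}")
```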

In [63]:
# Evaluate the unconstrained network
evaluate(unconstrained_NN, test_dl)
Out[63]:
tensor([1.5021, 1.6904, 1.0258, 9.3728, 0.6128], device='cuda:0')
In [64]:
# Evaluate the architecture-constrained network
evaluate(architecture_constrained_NN, test_dl)
Out[64]:
tensor([1.5108, 1.6561, 1.0055, 9.2862, 0.5842], device='cuda:0')
In [65]:
# Evaluate the loss-constrained network
evaluate(loss_constrained_NN, test_dl)
Out[65]:
tensor([1.4991, 1.6445, 1.0089, 9.3269, 0.5827], device='cuda:0')
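The MAE alone does not measure physical consistency. One simple necessary condition among these targets is that the dew point temperature never exceed the air temperature. A hedged sketch of such a check (the `consistency_violation` helper is hypothetical, not part of the notebook, and assumes the column ordering [T, Td, p, RH, r]):

```python
import torch

def consistency_violation(pred):
    """Fraction of samples where the predicted dew point exceeds the
    air temperature (a necessary physical condition is Td <= T).
    Assumes prediction columns ordered as [T, Td, p, RH, r]."""
    T, Td = pred[:, 0], pred[:, 1]
    return torch.mean((Td > T).float()).item()

# Toy predictions: the second row violates Td <= T
preds = torch.tensor([[20.0, 15.0, 950.0, 70.0, 10.0],
                      [10.0, 12.0, 960.0, 99.0,  9.0]])
print(consistency_violation(preds))  # → 0.5
```

In the paper's setting, consistency is instead measured against the full thermodynamic state equations, which the architecture-constrained network satisfies by construction.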
In [67]:
#@title A possible solution for Q6
print('The unconstrained NN performance is:')
print(evaluate(unconstrained_NN, test_dl),'\n')

print('The architecture-constrained NN performance is:')
print(evaluate(architecture_constrained_NN, test_dl),'\n')

print('The loss-constrained NN performance is:')
print(evaluate(loss_constrained_NN, test_dl),'\n')

print('The raw performance of the unconstrained NN can be slightly better ')
print('than the raw performance of the architecture-constrained NN.\n')

print('However, this comes at the cost of physical consistency as ')
print('the unconstrained NN takes "shortcuts" and violates physical laws.\n')

print('Architecture-constrained NNs enforce thermodynamic state equations')
print('to within machine precision without a noticeable trade-off in performance.\n')

print('Loss-constrained NNs are easy to implement but it can be difficult ')
print('to choose the right value for the hyperparameter alpha, as ')
print('the loss-constrained implementation always leads to ')
print('a trade-off between performance and physical consistency.')
The unconstrained NN performance is:
tensor([1.5021, 1.6904, 1.0258, 9.3728, 0.6128], device='cuda:0') 

The architecture-constrained NN performance is:
tensor([1.5108, 1.6561, 1.0055, 9.2862, 0.5842], device='cuda:0') 

The loss-constrained NN performance is:
tensor([1.4991, 1.6445, 1.0089, 9.3269, 0.5827], device='cuda:0') 

The raw performance of the unconstrained NN can be slightly better 
than the raw performance of the architecture-constrained NN.

However, this comes at the cost of physical consistency as 
the unconstrained NN takes "shortcuts" and violates physical laws.

Architecture-constrained NNs enforce thermodynamic state equations
to within machine precision without a noticeable trade-off in performance.

Loss-constrained NNs are easy to implement but it can be difficult 
to choose the right value for the hyperparameter alpha, as 
the loss-constrained implementation always leads to 
a trade-off between performance and physical consistency.

[Image: Stormy rainbow view over Locarno, by Marcell Faber]

🌈 Phew, after your successful physically-constrained, deep learning forecast, the storm is over. But hopefully, your quest for knowledge isn't 📚 Stay tuned for this session's second exercise, in which you will learn how to use physical knowledge, this time to make a neural network generalize over vastly different climates!

Source: Photo by Marcell Faber licensed under the Adobe Stock standard license